#!/usr/bin/env python
# coding: utf-8

import numpy as np
from skmultilearn.problem_transform import ClassifierChain, BinaryRelevance
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.svm import SVC
import sklearn.metrics as metrics
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
import statsmodels.formula.api as smf
from typing import List
import statsmodels.api as sm
import seaborn as sns
import statsmodels.formula.api as smf
import warnings
# Ignore all warnings process-wide (every category, from any module).
warnings.filterwarnings("ignore")
# Seed NumPy's global random number generator with a fixed value when the
# module is imported; run_models_final_split also re-seeds it on every call.
np.random.seed(2)


def _logit_probability_shift(params):
    """Return P(y=1 | x=1) - P(y=1 | x=0) for a logit with one binary regressor.

    ``params`` is a fitted statsmodels params Series ordered
    (intercept, slope). ``.iloc`` is used because integer ``[]`` lookups on a
    labelled Series are deprecated in pandas.
    """
    intercept, slope = params.iloc[0], params.iloc[1]
    p_x0 = 1 / (1 + np.exp(-intercept))
    p_x1 = 1 / (1 + np.exp(-(intercept + slope)))
    return p_x1 - p_x0


def _ci_covers(params, bse, truth):
    """Return 1 where params +/- 1.96*bse strictly contains truth, else 0."""
    return ((params - 1.96 * bse < truth) & (params + 1.96 * bse > truth)) * 1


def _fit_ols_against_truth(formula, data, true_beta, z):
    """Fit an OLS model and compare it with the known generating coefficients.

    Returns ``(abs_error, bse, z_error, params, coverage)``. ``z_error`` is
    sqrt(sum of squared in-sample residuals of ``z``) divided by ``len(z)``.
    """
    fit = smf.ols(formula, data=data).fit()
    # predict() is 1-D; z must be 1-D too, otherwise broadcasting against an
    # (n, 1) column silently produces an (n, n) residual matrix.
    residuals = np.asarray(fit.predict()).ravel() - z
    z_error = np.sqrt(np.sum(np.square(residuals))) / len(z)
    return (np.abs(fit.params - true_beta), fit.bse, z_error, fit.params,
            _ci_covers(fit.params, fit.bse, true_beta))


def prepare_regressions(
        Y_tuple, Y_reg, reg_type='multiple'):
    """
    Run the auxiliary regressions on true and predicted labels.

    Parameters
    ----------
    Y_tuple : tuple of five 1-D arrays
        The true labels ``(y1, y2, y3, y4, y5)``.
    Y_reg : tuple ``(Y_br, Y_cc)``
        Label matrices with 5 columns each (Y1..Y5) built from the Binary
        Relevance and Classifier Chain outputs; both have the same number of
        rows ``n`` as the true labels.
    reg_type : {'multiple', 'single', 'glm'}
        * ``'multiple'``: Z = 1 + Y4 + Y2 + Y3 + N(0, 1) is simulated from the
          true labels, then ``Z ~ Y4 + Y2 + Y3`` is fitted by OLS on the BR
          and CC labels. True coefficients are all ones.
        * ``'single'``: same, with ``Z = 1 + Y4 + noise`` and ``Z ~ Y4``.
        * ``'glm'``: logit of Y4 on Y3. The "true" coefficients are those of
          the logit fitted on the true labels.

    Returns
    -------
    tuple
        ``(br_error, br_bse, z_br_error, cc_error, cc_bse, z_cc_error,
        br_params, cc_params, br_PC, cc_PC)`` where
        * ``*_error`` is the absolute difference between estimated and true
          coefficients.
        * ``*_bse`` are the coefficient standard errors.
        * ``z_*_error`` is, for ``'glm'``, the absolute difference between
          the estimated and true change in P(Y4=1) when Y3 goes 0 -> 1. For
          OLS it is sqrt(sum of squared residuals of Z) / n.
        * ``*_params`` are the estimated coefficients.
        * ``*_PC`` is 1 where the +/-1.96*bse interval contains the true
          coefficient, else 0.

    Raises
    ------
    ValueError
        If ``reg_type`` is not one of the supported values.
    """
    y1, y2, y3, y4, y5 = Y_tuple
    Y_br, Y_cc = Y_reg
    n = Y_br.shape[0]
    noise_magnitude = 1
    label_columns = [
        'Y1_br', 'Y2_br', 'Y3_br', 'Y4_br', 'Y5_br',
        'Y1_cc', 'Y2_cc', 'Y3_cc', 'Y4_cc', 'Y5_cc']

    if reg_type == 'glm':
        # Name columns by value so 'y4 ~ y3' really regresses y4 on y3,
        # matching the Y4 ~ Y3 specification used for BR and CC below.
        true_df = pd.DataFrame({'y3': y3, 'y4': y4})
        # Binomial() defaults to the logit link. Passing the link *class*
        # (links.logit) is rejected by recent statsmodels releases.
        true_glm = smf.glm(
            formula="y4 ~ y3", family=sm.families.Binomial(), data=true_df
        ).fit()

        regression_df = pd.DataFrame(
            data=np.concatenate([Y_br, Y_cc], axis=1), columns=label_columns)
        br_glm = smf.glm(
            formula="Y4_br ~ Y3_br", family=sm.families.Binomial(),
            data=regression_df).fit()
        cc_glm = smf.glm(
            formula="Y4_cc ~ Y3_cc", family=sm.families.Binomial(),
            data=regression_df).fit()

        # .values: compare by position, since parameter labels differ.
        true_params = true_glm.params.values
        true_shift = _logit_probability_shift(true_glm.params)

        br_error = np.abs(br_glm.params - true_params)
        z_br_error = np.abs(_logit_probability_shift(br_glm.params) - true_shift)
        br_PC = _ci_covers(br_glm.params, br_glm.bse, true_params)

        cc_error = np.abs(cc_glm.params - true_params)
        z_cc_error = np.abs(_logit_probability_shift(cc_glm.params) - true_shift)
        cc_PC = _ci_covers(cc_glm.params, cc_glm.bse, true_params)

        return (br_error, br_glm.bse, z_br_error, cc_error, cc_glm.bse,
                z_cc_error, br_glm.params, cc_glm.params, br_PC, cc_PC)

    y0 = np.ones(n)
    if reg_type == 'multiple':
        regressors = ['Y4', 'Y2', 'Y3']
        true_design = np.column_stack([y0, y4, y2, y3])
    elif reg_type == 'single':
        regressors = ['Y4']
        true_design = np.column_stack([y0, y4])
    else:
        raise ValueError(
            f"reg_type must be 'multiple', 'single' or 'glm', got {reg_type!r}")

    # Simulated outcome: every coefficient (intercept included) equals 1.
    Beta_reg = np.ones((true_design.shape[1], 1))
    Z = np.dot(true_design, Beta_reg) + np.random.randn(n, 1) * noise_magnitude
    true_beta = Beta_reg.ravel()

    regression_df = pd.DataFrame(
        data=np.concatenate([Z, Y_br, Y_cc], axis=1),
        columns=['Z'] + label_columns)

    br_formula = 'Z ~ ' + ' + '.join(f'{name}_br' for name in regressors)
    cc_formula = 'Z ~ ' + ' + '.join(f'{name}_cc' for name in regressors)

    br_error, br_bse, z_br_error, br_params, br_PC = _fit_ols_against_truth(
        br_formula, regression_df, true_beta, Z[:, 0])
    cc_error, cc_bse, z_cc_error, cc_params, cc_PC = _fit_ols_against_truth(
        cc_formula, regression_df, true_beta, Z[:, 0])

    return (br_error, br_bse, z_br_error, cc_error, cc_bse, z_cc_error,
            br_params, cc_params, br_PC, cc_PC)


def _to_dense(matrix):
    """Return *matrix* as a dense ndarray.

    scikit-multilearn returns scipy sparse matrices. ``.toarray()`` works on
    every SciPy version, whereas the ``.A`` shortcut was removed from sparse
    matrices in SciPy 1.14.
    """
    if hasattr(matrix, 'toarray'):
        return matrix.toarray()
    return np.asarray(matrix)


def _binarize(latent):
    """Threshold a latent score at its median plus N(0, 0.1**2) jitter.

    Cutting near the median is done to avoid class imbalance.
    """
    return (latent > np.median(latent) + np.random.randn(1) / 10) * 1


def _simulate_labels(xs, betas, corr_type, vcv, error_magnitude):
    """Draw one replication of the five binary labels ``(y1, ..., y5)``.

    ``xs`` holds the six covariates (x0, ..., x5) and ``betas`` the five
    coefficient vectors. ``corr_type`` selects which earlier labels replace
    covariates in the design of later labels:
    0 -> none; 1 -> one earlier label each; 2 -> up to two earlier labels;
    any other value -> labels 4 and 5 use three earlier labels.
    """
    x0, x1, x2, x3, x4, x5 = xs
    By1, By2, By3, By4, By5 = betas
    X_base = np.column_stack(xs)
    n = X_base.shape[0]

    e = np.random.multivariate_normal(
        [0, 0, 0, 0, 0], cov=vcv, size=n) * error_magnitude

    y1 = _binarize(np.dot(X_base, By1) + e[:, 0])

    if corr_type == 0:
        X = X_base
    else:
        X = np.column_stack([x0, x1, x2, x3, x4, y1])
    y2 = _binarize(np.dot(X, By2) + e[:, 1])

    if corr_type == 0:
        X = X_base
    elif corr_type == 1:
        X = np.column_stack([x0, x1, x2, x3, x4, y1])
    else:
        X = np.column_stack([x0, x1, x2, x3, y1, y2])
    y3 = _binarize(np.dot(X, By3) + e[:, 2])

    if corr_type == 0:
        X = X_base
    elif corr_type == 1:
        X = np.column_stack([x5, x1, x2, x0, x4, y2])
    elif corr_type == 2:
        X = np.column_stack([x5, x1, x2, x0, y3, y2])
    else:
        X = np.column_stack([x5, x1, x2, y1, y3, y2])
    y4 = _binarize(np.dot(X, By4) + e[:, 3])

    if corr_type == 0:
        X = X_base
    elif corr_type == 1:
        X = np.column_stack([x0, x1, x2, x5, x4, y1])
    elif corr_type == 2:
        X = np.column_stack([x0, x1, x2, x5, y3, y1])
    else:
        X = np.column_stack([x0, x1, x2, y4, y3, y1])
    y5 = _binarize(np.dot(X, By5) + e[:, 4])

    return y1, y2, y3, y4, y5


def _ecc_predict_proba(Xtrain, Ytrain, Xtest, n_chains=5):
    """Average label probabilities over classifier chains with random orders.

    Each chain is trained on a random permutation of the label columns; its
    probabilities are put back into the original column order before
    averaging.
    """
    chain_probs = []
    for _ in range(n_chains):
        order = np.random.permutation(Ytrain.shape[1])
        restore = np.argsort(order)
        chain = ClassifierChain(LogisticRegression(
            solver='saga', max_iter=2500))
        chain.fit(Xtrain, Ytrain[:, order])
        chain_probs.append(_to_dense(chain.predict_proba(Xtest))[:, restore])
    return np.mean(chain_probs, axis=0)


def _frame_with_split(rows, test_split):
    """Build a DataFrame from *rows* and tag every row with *test_split*."""
    frame = pd.DataFrame(rows)
    frame['test_split'] = test_split
    return frame


def run_models_final_split(
                    corr_type: int = 2, test_split=0.4,
                    nsims: int = 1000, n_obs: int = 1000):
    """
    Run the Monte Carlo simulation for one test split and correlation type.

    In each of ``nsims`` replications five binary labels are simulated from
    fixed Poisson covariates. The last ``test_split`` share of rows is held
    out, and its labels are predicted with Binary Relevance (BR) and with an
    ensemble of 5 classifier chains (ECC). The true training labels plus the
    predicted test labels are then used in the 'single', 'multiple' and 'glm'
    auxiliary regressions of ``prepare_regressions``.

    corr_type = 0, no correlation among labels
    corr_type = 1, up to 1 label will be dependent on another
    corr_type = 2, up to 2 labels will depend on previous labels
    any other value, labels 4 and 5 depend on 3 previous labels

    Parameters
    ----------
    corr_type : int
        Label dependence structure (see above).
    test_split : float
        Fraction of the ``n_obs`` rows used as the test set.
    nsims : int
        Number of Monte Carlo replications.
    n_obs : int
        Number of observations per replication.

    Returns
    -------
    tuple of 10 DataFrames
        ``(temp_br, temp_cc, cc_mse, cc_bse, br_mse, br_bse, br_beta,
        cc_beta, br_pc, cc_pc)``. Every table has a ``test_split`` column;
        the regression tables also carry a ``reg_type`` column.

    Raises
    ------
    ValueError
        If ``test_split`` leaves the training or the test set empty.
    """
    n_test = int(test_split * n_obs)
    n_train = n_obs - n_test
    if n_test < 1 or n_train < 1:
        raise ValueError(
            f'test_split={test_split} with n_obs={n_obs} leaves '
            f'{n_train} training and {n_test} test rows; both must be >= 1')

    np.random.seed(2)
    error_magnitude = 1
    correlation = 0  # correlation for the noise
    beta_mag = 5

    # Poisson covariates; x0 is the constant. Drawn once and reused by
    # every replication.
    xs = (np.ones(n_obs),
          np.random.poisson(1, n_obs),
          np.random.poisson(0.25, n_obs),
          np.random.poisson(0.1, n_obs),
          np.random.poisson(2, n_obs),
          np.random.poisson(0.5, n_obs))
    X = np.column_stack(xs)

    # One coefficient vector per label.
    betas = tuple(np.random.randn(X.shape[1]) * beta_mag for _ in range(5))

    # Loop-invariant: noise covariance and the scaled train/test covariates.
    vcv = correlation * np.ones((5, 5))
    np.fill_diagonal(vcv, 1)
    scaler = StandardScaler()
    Xtrain = scaler.fit_transform(X[:n_train])
    Xtest = scaler.transform(X[n_train:])

    br_score = []
    cc_score = []
    reg_error_dict = {'CC_error_multiple': [], 'CC_error_single': [],
                      'BR_error_multiple': [], 'BR_error_single': [],
                      'BR_error_glm': [], 'CC_error_glm': []}

    cc_mse_list, cc_bse_list, cc_beta_list, cc_pc_list = [], [], [], []
    br_mse_list, br_bse_list, br_beta_list, br_pc_list = [], [], [], []

    print(f'Running {nsims} simulations with {100*test_split}% testing data')

    for _ in range(nsims):
        Y_tuple = _simulate_labels(xs, betas, corr_type, vcv, error_magnitude)
        Y = np.column_stack(Y_tuple)
        Ytrain, Ytest = Y[:n_train], Y[n_train:]

        br = BinaryRelevance(LogisticRegression(solver='saga', max_iter=2500))
        br.fit(Xtrain, Ytrain)

        ens_out = _ecc_predict_proba(Xtrain, Ytrain, Xtest)
        cc_predict = 1 * (ens_out > 0.5)
        br_predict = _to_dense(br.predict(Xtest))

        # Training rows keep their true labels; test rows get predictions.
        Y_reg = (np.concatenate((Ytrain, br_predict)),
                 np.concatenate((Ytrain, cc_predict)))

        for reg_type in ('single', 'multiple', 'glm'):
            (br_error, br_bse, z_br_error,
             cc_error, cc_bse, z_cc_error,
             br_beta, cc_beta, br_pc, cc_pc) = prepare_regressions(
                Y_tuple, Y_reg, reg_type=reg_type)

            for stat, store in ((br_error, br_mse_list),
                                (br_bse, br_bse_list),
                                (br_beta, br_beta_list),
                                (br_pc, br_pc_list),
                                (cc_error, cc_mse_list),
                                (cc_bse, cc_bse_list),
                                (cc_beta, cc_beta_list),
                                (cc_pc, cc_pc_list)):
                stat['reg_type'] = reg_type
                store.append(stat)

            # The printed "Z Mean error" is the mean of sqrt(z error).
            reg_error_dict['CC_error_' + reg_type].append(np.sqrt(z_cc_error))
            reg_error_dict['BR_error_' + reg_type].append(np.sqrt(z_br_error))

        br_score.append({
            'BR Accuracy': metrics.accuracy_score(Ytest, br_predict),
            'BR Hamming': metrics.hamming_loss(Ytest, br_predict),
            'BR F1 Macro': metrics.f1_score(Ytest, br_predict, average='macro'),
            'BR F1 Micro': metrics.f1_score(Ytest, br_predict, average='micro'),
            'BR Ranking Loss': metrics.label_ranking_loss(
                Ytest, _to_dense(br.predict_proba(Xtest))),
        })
        cc_score.append({
            'ECC Accuracy': metrics.accuracy_score(Ytest, cc_predict),
            'ECC Hamming': metrics.hamming_loss(Ytest, cc_predict),
            'ECC F1 Macro': metrics.f1_score(Ytest, cc_predict, average='macro'),
            'ECC F1 Micro': metrics.f1_score(Ytest, cc_predict, average='micro'),
            'ECC Ranking Loss': metrics.label_ranking_loss(Ytest, ens_out),
        })

    temp_br = _frame_with_split(br_score, test_split)
    temp_cc = _frame_with_split(cc_score, test_split)
    cc_mse = _frame_with_split(cc_mse_list, test_split)
    cc_bse = _frame_with_split(cc_bse_list, test_split)
    br_mse = _frame_with_split(br_mse_list, test_split)
    br_bse = _frame_with_split(br_bse_list, test_split)
    br_beta = _frame_with_split(br_beta_list, test_split)
    cc_beta = _frame_with_split(cc_beta_list, test_split)
    br_pc = _frame_with_split(br_pc_list, test_split)
    cc_pc = _frame_with_split(cc_pc_list, test_split)

    print('Single regression Z Mean error = {:.4f} for BR and {:.4f} for CC'.format(
        np.mean(reg_error_dict['BR_error_single']),np.mean(reg_error_dict['CC_error_single'])))
    print('Multiple regression Z Mean error = {:.4f} for BR and {:.4f} for CC'.format(
        np.mean(reg_error_dict['BR_error_multiple']),np.mean(reg_error_dict['CC_error_multiple'])))
    print('GLM Z Mean error = {:.4f} for BR and {:.4f} for CC'.format(
        np.mean(reg_error_dict['BR_error_glm']),np.mean(reg_error_dict['CC_error_glm'])))

    return temp_br, temp_cc, cc_mse, cc_bse, br_mse, br_bse, br_beta, cc_beta, br_pc, cc_pc


def main():
    """
    Run the simulation for several test-set fractions and write the results.

    ``run_models_final_split`` is called with ``corr_type=2`` for every
    fraction in ``data_splits``. Each of the ten tables it returns is
    concatenated across fractions and written to a CSV file under
    ``../results`` (relative to the current working directory). Missing
    output directories are created first, so ``to_csv`` does not fail on a
    fresh checkout.
    """
    from pathlib import Path

    data_splits = [0.15, 0.30, 0.45, 0.60]

    # One output file per element of run_models_final_split's return tuple,
    # in that tuple's order:
    # (br, cc, cc_mse, cc_bse, br_mse, br_bse, br_beta, cc_beta, br_pc, cc_pc)
    output_files = (
        '../results/MC Classification/BR_classification.csv',
        '../results/MC Classification/ECC_classification.csv',
        '../results/MC Regression/ECC_MSE.csv',
        '../results/MC Regression/ECC_BSE.csv',
        '../results/MC Regression/BR_MSE.csv',
        '../results/MC Regression/BR_BSE.csv',
        '../results/MC Regression/BR_beta.csv',
        '../results/MC Regression/ECC_beta.csv',
        '../results/MC Regression/BR_CP.csv',
        '../results/MC Regression/ECC_CP.csv',
    )

    collected = [[] for _ in output_files]
    for split in data_splits:
        tables = run_models_final_split(corr_type=2, test_split=split)
        for bucket, table in zip(collected, tables):
            bucket.append(table)

    for path, frames in zip(output_files, collected):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        pd.concat(frames).to_csv(path)

if __name__ == '__main__':
    # Start the simulation only when run as a script, never on import.
    main()